#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"

extern "C" {
    #include "activations.h"
    #include "cuda.h"
}

// Leaky hard-tanh activation: identity on [0, 1]; outside that interval the
// output continues with slope .001 (anchored at 0 below and at 1 above).
// A NaN input fails both comparisons and is returned unchanged.
__device__ float lhtan_activate_kernel(float x)
{
    return (x < 0) ? .001f*x
         : (x > 1) ? .001f*(x-1.f) + 1.f
         : x;
}

// Slope of the leaky hard-tanh: 1 strictly inside (0, 1), .001 everywhere
// else (including the end points and NaN).
__device__ float lhtan_gradient_kernel(float x)
{
    const bool inside = (x > 0) && (x < 1);
    return inside ? 1.f : .001f;
}

// Hard-tanh activation: clamps x to [-1, 1]. Written with explicit
// comparisons so a NaN input falls through and is returned unchanged.
__device__ float hardtan_activate_kernel(float x)
{
    return (x < -1) ? -1.f
         : (x > 1)  ?  1.f
         : x;
}

// Identity activation: the input is passed through untouched.
__device__ float linear_activate_kernel(float x)
{
    return x;
}
// Logistic sigmoid: 1 / (1 + e^-x).
__device__ float logistic_activate_kernel(float x)
{
    const float decay = expf(-x);
    return 1.f / (1.f + decay);
}
// "Loggy" activation: 2 / (1 + e^-x) - 1.
__device__ float loggy_activate_kernel(float x)
{
    const float denom = 1.f + expf(-x);
    return 2.f / denom - 1.f;
}
// ReLU computed as x times a 0/1 gate (gate is 1 only when x > 0).
// The product form is kept on purpose: it yields the same results as a
// multiply for every input, including NaN and signed zero.
__device__ float relu_activate_kernel(float x)
{
    const float gate = (x > 0) ? 1.f : 0.f;
    return x * gate;
}
// ELU activation: x for x >= 0, e^x - 1 for x < 0.
// The branch is selected rather than blended with 0/1 masks: with masks,
// any x large enough for expf(x) to overflow to +inf produced 0*inf = NaN
// instead of x (and x = -inf produced NaN instead of -1).
// A NaN input still propagates, through expf.
__device__ float elu_activate_kernel(float x)
{
    return (x >= 0) ? x : expf(x) - 1.f;
}
// SELU activation with scale 1.0507 and alpha 1.6732:
//   x >= 0 : 1.0507 * x
//   x <  0 : 1.0507 * 1.6732 * (e^x - 1)
// The branch is selected rather than blended with 0/1 masks: with masks,
// any x large enough for expf(x) to overflow to +inf produced 0*inf = NaN
// instead of the scaled input. A NaN input still propagates, through expf.
__device__ float selu_activate_kernel(float x)
{
    return (x >= 0) ? 1.0507f*x
                    : 1.0507f*1.6732f*(expf(x) - 1.f);
}
// "Relie" activation: identity for x > 0, slope .01 otherwise.
__device__ float relie_activate_kernel(float x)
{
    if (x > 0) {
        return x;
    }
    return .01f*x;
}
// Ramp activation: a 0/1-gated copy of x (gate is 1 only when x > 0) plus
// .1 * x, giving slope 1.1 for positive inputs and .1 otherwise.
__device__ float ramp_activate_kernel(float x)
{
    const float gate = (x > 0) ? 1.f : 0.f;
    return gate*x + .1f*x;
}
// Leaky ReLU: identity for x > 0, slope .1 otherwise.
__device__ float leaky_activate_kernel(float x)
{
    if (x > 0) {
        return x;
    }
    return .1f*x;
}
// Hyperbolic tangent evaluated as 2 / (1 + e^(-2x)) - 1.
__device__ float tanh_activate_kernel(float x)
{
    const float decay = expf(-2.f*x);
    return 2.f / (1.f + decay) - 1.f;
}
// Piecewise-linear activation:
//   x < -4       : .01 * (x + 4)
//   x >  4       : .01 * (x - 4) + 1
//   otherwise    : .125 * x + .5
__device__ float plse_activate_kernel(float x)
{
    float y;
    if (x < -4) {
        y = .01f * (x + 4);
    } else if (x > 4) {
        y = 0.01f * (x - 4) + 1;
    } else {
        y = 0.125f*x + .5f;
    }
    return y;
}
// Staircase activation. With n = floor(x):
//   n even : floor(x / 2)                   (flat tread)
//   n odd  : (x - n) + floor(x / 2)         (unit-slope riser)
__device__ float stair_activate_kernel(float x)
{
    const int whole = floorf(x);
    const float half_floor = floorf(x/2);
    if (whole%2 != 0) {
        return (x-whole) + half_floor;
    }
    return half_floor;
}
// Slope of hard-tanh: 1 strictly inside (-1, 1), 0 everywhere else
// (including the end points and NaN).
__device__ float hardtan_gradient_kernel(float x)
{
    const bool inside = (x > -1) && (x < 1);
    return inside ? 1.f : 0.f;
}
// Slope of the identity activation: always 1, independent of x.
__device__ float linear_gradient_kernel(float x)
{
    return 1.f;
}
// Returns (1 - x) * x.
// NOTE(review): this is the logistic derivative only if x is the already
// activated output rather than the pre-activation input - confirm at caller.
__device__ float logistic_gradient_kernel(float x)
{
    const float complement = 1.f - x;
    return complement*x;
}
// Maps x from [-1, 1] onto [0, 1] via (x + 1) / 2, then returns
// 2 * (1 - u) * u for that rescaled value u.
// NOTE(review): presumably x is the loggy activation output - confirm.
__device__ float loggy_gradient_kernel(float x)
{
    const float unit = (x + 1.f) / 2.f;
    return 2.f * (1.f - unit) * unit;
}
// Slope of ReLU: 1 when x > 0, otherwise 0.
__device__ float relu_gradient_kernel(float x)
{
    return (x > 0) ? 1.f : 0.f;
}
// Blend of two 0/1 masks: contributes 1 when x >= 0 and (x + 1) when x < 0.
// NOTE(review): x + 1 is the ELU slope only if x is the activation output -
// confirm at caller.
__device__ float elu_gradient_kernel(float x)
{
    const float non_negative = (x >= 0) ? 1.f : 0.f;
    const float negative     = (x < 0)  ? 1.f : 0.f;
    return non_negative + negative*(x + 1.f);
}
// SELU slope with scale 1.0507 and alpha 1.6732:
//   x >= 0 : 1.0507
//   x <  0 : x + 1.0507 * 1.6732
// Float literals keep the arithmetic in single precision (the unsuffixed
// constants promoted the whole expression to double), matching
// selu_activate_kernel. Selecting the branch instead of blending 0/1 masks
// also avoids 0*inf = NaN when x is +inf; a NaN input still yields NaN.
// NOTE(review): the x < 0 formula is the SELU derivative only if x is the
// activation output - confirm at caller.
__device__ float selu_gradient_kernel(float x)
{
    return (x >= 0) ? 1.0507f : x + 1.0507f*1.6732f;
}
// Slope of the "relie" activation: 1 for x > 0, .01 otherwise.
__device__ float relie_gradient_kernel(float x)
{
    if (x > 0) {
        return 1.f;
    }
    return 0.01f;
}
// Slope of the ramp activation: a 0/1 step at x > 0 plus a constant .1.
__device__ float ramp_gradient_kernel(float x)
{
    const float step = (x > 0) ? 1.f : 0.f;
    return step + 0.1f;
}
// Slope of leaky ReLU: 1 for x > 0, .1 otherwise.
__device__ float leaky_gradient_kernel(float x)
{
    if (x > 0) {
        return 1.f;
    }
    return .1f;
}
// Returns 1 - x^2.
// NOTE(review): this is the tanh derivative only if x is the activation
// output rather than the pre-activation input - confirm at caller.
__device__ float tanh_gradient_kernel(float x)
{
    return 1.f - x*x;
}
// Slope of the piecewise-linear activation: .01 when x lies outside [0, 1],
// .125 otherwise (NaN fails both comparisons and therefore yields .125).
__device__ float plse_gradient_kernel(float x)
{
    const bool outside = (x < 0) || (x > 1);
    if (outside) {
        return .01f;
    }
    return .125f;
}
// Slope of the staircase activation: 0 when x is exactly an integer value
// (x == floor(x)), 1 for every other input.
__device__ float stair_gradient_kernel(float x)
{
    const bool is_whole = (x == floorf(x));
    return is_whole ? 0.f : 1.f;
}
// Applies the activation selected by `a` to x by forwarding to the matching
// *_activate_kernel helper. Any value of `a` not listed below yields 0.
__device__ float activate_kernel(float x, ACTIVATION a)
{
    float out = 0;
    switch (a) {
        case LINEAR:   out = linear_activate_kernel(x);   break;
        case LOGISTIC: out = logistic_activate_kernel(x); break;
        case LOGGY:    out = loggy_activate_kernel(x);    break;
        case RELU:     out = relu_activate_kernel(x);     break;
        case ELU:      out = elu_activate_kernel(x);      break;
        case SELU:     out = selu_activate_kernel(x);     break;
        case RELIE:    out = relie_activate_kernel(x);    break;
        case RAMP:     out = ramp_activate_kernel(x);     break;
        case LEAKY:    out = leaky_activate_kernel(x);    break;
        case TANH:     out = tanh_activate_kernel(x);     break;
        case PLSE:     out = plse_activate_kernel(x);     break;
        case STAIR:    out = stair_activate_kernel(x);    break;
        case HARDTAN:  out = hardtan_activate_kernel(x);  break;
        case LHTAN:    out = lhtan_activate_kernel(x);    break;
        default:       break;   // unlisted activation: keep the 0 fallback
    }
    return out;
}

// Evaluates the gradient helper that corresponds to activation `a` at x by
// forwarding to the matching *_gradient_kernel. Any value of `a` not listed
// below yields 0.
__device__ float gradient_kernel(float x, ACTIVATION a)
{
    float slope = 0;
    switch (a) {
        case LINEAR:   slope = linear_gradient_kernel(x);   break;
        case LOGISTIC: slope = logistic_gradient_kernel(x); break;
        case LOGGY:    slope = loggy_gradient_kernel(x);    break;
        case RELU:     slope = relu_gradient_kernel(x);     break;
        case ELU:      slope = elu_gradient_kernel(x);      break;
        case SELU:     slope = selu_gradient_kernel(x);     break;
        case RELIE:    slope = relie_gradient_kernel(x);    break;
        case RAMP:     slope = ramp_gradient_kernel(x);     break;
        case LEAKY:    slope = leaky_gradient_kernel(x);    break;
        case TANH:     slope = tanh_gradient_kernel(x);     break;
        case PLSE:     slope = plse_gradient_kernel(x);     break;
        case STAIR:    slope = stair_gradient_kernel(x);    break;
        case HARDTAN:  slope = hardtan_gradient_kernel(x);  break;
        case LHTAN:    slope = lhtan_gradient_kernel(x);    break;
        default:       break;   // unlisted activation: keep the 0 fallback
    }
    return slope;
}

// Gradient of an elementwise product of two operands taken from x.
// For each thread id < n, with i = id % s and b = id / s, the two operands
// are x[b*s + i] and x[b*s + s/2 + i]; each operand's gradient is the other
// operand times dy[id], written to the same two offsets of dx.
//
//   x  : input values (read only here)
//   dy : upstream gradient, indexed by id
//   n  : number of threads that perform work
//   s  : stride used to split id into (b, i); the partner operand sits
//        s/2 elements further on
//   a  : unused in this kernel
//   dx : output gradient, written at the same offsets that are read from x
//
// The flat thread index is built from a 2D grid of 1D blocks
// (blockIdx.x, blockIdx.y, threadIdx.x).
// NOTE(review): b*s + i is always equal to id, so the partner element is
// x[id + s/2]; confirm the caller sizes x and dx so that id + s/2 is in
// range for every id < n.
__global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, BINARY_ACTIVATION a, float *dx)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;

    // Surplus threads of the last block must not touch memory at all: the
    // loads from x used to sit in front of this check and could read past
    // the end of the buffer.
    if (id >= n) return;

    int i = id % s;
    int b = id / s;
    int first  = b*s + i;
    int second = b*s + s/2 + i;

    // Both operands are loaded before either store, so the result is the
    // same even if dx and x refer to the same buffer.
    float x1 = x[first];
    float x2 = x[second];
    float de = dy[id];

    dx[first]  = x2*de;
    dx[second] = x1*de;
}
